# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
from pathlib import Path
from typing import List, Tuple, Union

import jax.numpy as jnp
import numpy as np
import torch
import torch.optim as optim
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader, TensorDataset

import secretflow as sf
from secretflow.device import PYU, PYUObject, proxy
from secretflow.utils.communicate import ForwardData
from secretflow_fl.ml.nn.callbacks.attack import AttackCallback


@proxy(PYUObject)
class ExploitAttackTrainer(object):
    """Trains and evaluates the ExPLoit surrogate model on the attacker's device.

    The surrogate is fit with a weighted sum of three terms (see ``custom_loss``):

    * accuracy: BCE against labels obtained by rounding ``p_hats``, divided by ``H_y``;
    * gradient matching: L1 distance between the target gradients passed to
      ``train`` and the surrogate's own gradients w.r.t. its input;
    * KL: pulls the surrogate's mean prediction towards the prior ``y_pri``.
    """

    def __init__(self, surrogate_model, H_y, y_pri) -> None:
        """
        Args:
            surrogate_model: zero-argument callable (e.g. an ``nn.Module``
                subclass) that builds the surrogate network; called once here.
            H_y: torch tensor that divides the accuracy (BCE) term.
            y_pri: torch tensor holding the label prior; it is compared
                against the surrogate's mean prediction ``[P(y=0), P(y=1)]``
                in the KL term.
        """
        self.surrogate_model = surrogate_model()
        self.H_y = H_y
        self.y_pri = y_pri

    @staticmethod
    def BCELoss(y_true, y_pred):
        """Mean binary cross-entropy; ``y_true`` is flattened to match ``y_pred``.

        ``torch.nn.BCELoss`` expects ``y_pred`` to already be probabilities
        (presumably the surrogate ends in a sigmoid — confirm for custom models).
        """
        y_true = y_true.view(-1)
        loss_fn = torch.nn.BCELoss()
        binary_crossentropy = loss_fn(y_pred, y_true)
        return binary_crossentropy

    def GradLoss(self, grad_target, X, y_true):
        """L1 distance between ``grad_target`` and the surrogate's input gradient.

        Args:
            grad_target: target gradients, same shape as ``X``.
            X: input batch.
            y_true: labels for the batch.

        Returns:
            Scalar tensor, differentiable w.r.t. the surrogate's parameters.
        """
        # Differentiate w.r.t. a detached leaf copy so the caller's tensor is
        # not mutated (the original flipped requires_grad on the caller's X).
        X = X.detach().requires_grad_(True)

        y_pred = self.surrogate_model(X).view(-1)
        loss = torch.mean(self.BCELoss(y_true, y_pred))

        # autograd.grad (unlike .backward()) leaves the parameters' .grad
        # untouched, so no stray BCE gradient leaks into the optimizer step.
        # create_graph=True keeps the input gradient differentiable, so this
        # term actually contributes to training instead of being a constant.
        (gradients,) = torch.autograd.grad(loss, X, create_graph=True)
        gradients = gradients.float()

        # Clip to limit the influence of extreme gradient entries.
        max_gradient_value = 50.0
        clipped_gradients = torch.clamp(
            gradients, -max_gradient_value, max_gradient_value
        )

        # BCELoss averages over the batch, which divides each per-sample input
        # gradient by the batch size; multiply it back before comparing.
        return torch.norm(
            grad_target - clipped_gradients * y_true.shape[0], p=1
        )

    @staticmethod
    def KLLoss(y_pri, y_pred):
        """KL divergence ``KL(y_pri || mean prediction)``.

        Args:
            y_pri: label prior, broadcastable against a ``(1, 2)`` tensor.
            y_pred: 1-D tensor of predicted ``P(y=1)``.

        Returns:
            Scalar ``sum(-y_pri * log(mean_pred / y_pri))``; ``eps`` guards
            against division by zero and ``log(0)``.
        """
        eps = 1e-9
        # Batch-mean distribution [P(y=0), P(y=1)], shape (1, 2).
        y_pred_mean = torch.mean(
            torch.stack([1 - y_pred, y_pred], dim=1), dim=0, keepdim=True
        )
        kl_divergence = -y_pri * torch.log(y_pred_mean / (y_pri + eps) + eps)
        kl_divergence = torch.sum(kl_divergence)

        return kl_divergence

    def custom_loss(
        self,
        X,
        y_true,
        y_pred,
        y_pri,
        H_y,
        grad_target,
        alpha_acc=1,
        alpha_grad=0.001,
        alpha_kl=0.001,
    ):
        """Weighted sum of the accuracy, gradient-matching and KL terms.

        Args:
            X: input batch (used to recompute input gradients).
            y_true: labels for the batch.
            y_pred: surrogate predictions for ``X`` (1-D probabilities).
            y_pri: label prior tensor.
            H_y: tensor divisor for the accuracy term.
            grad_target: target gradients for ``X``.
            alpha_acc, alpha_grad, alpha_kl: weights of the three terms.
        """
        H_y = H_y.float()
        y_pri = y_pri.float()
        loss_acc = self.BCELoss(y_true, y_pred) / H_y
        loss_grad = self.GradLoss(grad_target, X, y_true)
        loss_kl = self.KLLoss(y_pri, y_pred)
        return alpha_acc * loss_acc + alpha_grad * loss_grad + alpha_kl * loss_kl

    def get_dataloader(
        self,
        x,
        p_hats=None,
        grads=None,
        batch_size=128,
        mode="train",
    ):
        """Build a DataLoader over ``x`` with hard labels rounded from ``p_hats``.

        Args:
            x: array of input features.
            p_hats: predicted probabilities; rounded to {0, 1} to form labels.
            grads: target gradients (``mode="train"`` only).
            batch_size: DataLoader batch size.
            mode: ``"train"`` yields shuffled ``(x, label, grad)`` batches;
                ``"eval"`` yields ordered ``(x, label)`` batches.

        Raises:
            ValueError: if ``mode`` is neither ``"train"`` nor ``"eval"``.
        """
        if mode not in ("train", "eval"):
            raise ValueError(f"mode must be 'train' or 'eval', got {mode!r}")

        # atleast_1d keeps a single-sample input from squeezing down to a
        # 0-d array, which TensorDataset would reject.
        labels = np.atleast_1d(np.squeeze(np.round(p_hats)))
        features = torch.tensor(x, dtype=torch.float32)
        labels = torch.tensor(labels, dtype=torch.float32)

        if mode == "train":
            targets = torch.tensor(grads, dtype=torch.float32)
            dataset = TensorDataset(features, labels, targets)
            return DataLoader(dataset, batch_size=batch_size, shuffle=True)

        return DataLoader(TensorDataset(features, labels), batch_size=batch_size)

    @staticmethod
    def _binary_auc(y_true_list, y_score_list):
        """ROC-AUC over concatenated batches; -1 when only one class is present."""
        y_true = np.concatenate(y_true_list)
        y_score = np.concatenate(y_score_list)
        if np.all(y_true == 0) or np.all(y_true == 1):
            logging.warning("All y_true are 0 or 1, skipping AUC calculation.")
            return -1  # sentinel: AUC is undefined for a single class
        return roc_auc_score(y_true, y_score)

    def train(
        self,
        x_train,
        p_hats,
        grads,
        epochs,
        batch_size,
        alpha_acc=1,
        alpha_grad=0.001,
        alpha_kl=0.001,
    ):
        """Fit the surrogate for ``epochs`` passes over the training data.

        Each epoch logs the mean per-batch loss, accuracy and ROC-AUC against
        the rounded ``p_hats`` labels.
        """
        optimizer = optim.Adam(self.surrogate_model.parameters(), lr=0.01)
        dataloader = self.get_dataloader(
            x=x_train,
            p_hats=p_hats,
            grads=grads,
            batch_size=batch_size,
            mode="train",
        )
        self.surrogate_model.train()

        for epoch in range(epochs):
            total_loss = 0.0
            num_batches = 0
            correct_predictions = 0
            total_samples = 0
            y_true_list = []
            y_pred_list = []

            for X, y, grad_target in dataloader:
                optimizer.zero_grad()
                y_pred = self.surrogate_model(X).view(-1)
                loss = self.custom_loss(
                    X,
                    y,
                    y_pred,
                    self.y_pri,
                    self.H_y,
                    grad_target,
                    alpha_acc=alpha_acc,
                    alpha_grad=alpha_grad,
                    alpha_kl=alpha_kl,
                )
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                num_batches += 1
                y_pred_detached = y_pred.detach()
                y_pred_list.append(y_pred_detached.cpu().numpy())
                y_true_list.append(y.cpu().numpy())
                correct_predictions += (
                    (torch.round(y_pred_detached) == y).sum().item()
                )
                total_samples += X.size(0)

            if total_samples == 0:
                logging.warning("No training samples for the surrogate model.")
                break

            # loss.item() is already a batch mean, so average over batches.
            average_loss = total_loss / num_batches
            accuracy = correct_predictions / total_samples
            auc_value = self._binary_auc(y_true_list, y_pred_list)

            logging.info(
                f"Epoch {epoch+1}/{epochs}, Loss: {average_loss:.4f}, Accuracy: {accuracy:.4f}, AUC: {auc_value:.4f}"
            )

    def evaluate(
        self,
        x_test,
        p_hats,
        batch_size=128,
    ):
        """Score the surrogate against labels rounded from ``p_hats``.

        Returns:
            ``{"accuracy": float, "auc_value": float}``; ``auc_value`` is -1
            when the labels contain a single class.

        Raises:
            ValueError: if ``x_test`` contains no samples.
        """
        dataloader = self.get_dataloader(
            x=x_test,
            p_hats=p_hats,
            batch_size=batch_size,
            mode="eval",
        )

        self.surrogate_model.eval()  # disable dropout/batch-norm updates
        correct_predictions = 0
        total_samples = 0
        y_true_list = []
        y_pred_list = []
        with torch.no_grad():
            for X, y in dataloader:
                y_pred = self.surrogate_model(X).view(-1)
                correct_predictions += (torch.round(y_pred) == y).sum().item()
                total_samples += X.shape[0]
                y_pred_list.append(y_pred.cpu().numpy())
                y_true_list.append(y.cpu().numpy())

        if total_samples == 0:
            raise ValueError("No evaluation samples available for the surrogate model.")

        auc_value = self._binary_auc(y_true_list, y_pred_list)
        accuracy = correct_predictions / total_samples
        logging.info(f"Test Accuracy: {accuracy:.4f}, AUC: {auc_value:.4f}")
        metrics = {"accuracy": float(accuracy), "auc_value": float(auc_value)}
        return metrics

    def save_model(self, model_path):
        """Save the surrogate's ``state_dict`` to ``model_path``, creating parent dirs."""
        # Validate before Path(): Path(None) would raise TypeError first.
        assert model_path is not None, "model path cannot be empty"
        Path(model_path).parent.mkdir(parents=True, exist_ok=True)
        check_point = {
            "surrogate_model": self.surrogate_model.state_dict(),
        }
        torch.save(check_point, model_path)

    def load_model(self, model_path):
        """Load a checkpoint written by ``save_model`` into the surrogate.

        NOTE(review): ``torch.load`` unpickles the file; only load trusted paths.
        """
        assert model_path is not None, "model path cannot be empty"
        checkpoint = torch.load(model_path)
        self.surrogate_model.load_state_dict(checkpoint["surrogate_model"])


class ExploitAttack(AttackCallback):
    def __init__(
        self,
        attack_party: PYU,
        batch_size=32,
        epochs=1,
        surrogate_model=None,
        H_y=None,
        y_pri=None,
        alpha_acc=1,
        alpha_grad=0.001,
        alpha_kl=0.001,
        **params,
    ):
        """
        ExploitAttack in torch backend.

        The callback collects data while the split-learning model is in its last
        epoch (``epoch == params["epochs"] - 1``):

        * the attacker's aggregated hidden outputs;
        * the gradients returned to the attacker;
        * the ``device_y`` worker's predictions.

        ``on_train_end`` fits a surrogate model on the train-stage data.
        ``get_attack_metrics`` then scores it on the eval-stage data.

        Args:
            attack_party (PYU): The party that runs the attack and hosts the surrogate model.
            batch_size (int): Batch size for surrogate training. Default is 32.
            epochs (int): Number of epochs to train the surrogate. Default is 1.
            surrogate_model: Zero-argument callable (e.g. an ``nn.Module`` subclass) that builds the surrogate network. It must not be None, because ExploitAttackTrainer calls it with no arguments.
            H_y: Torch tensor that divides the accuracy (BCE) term of the surrogate loss.
            y_pri: Torch tensor with the prior label distribution, used in the KL term of the surrogate loss.
            alpha_acc: Weight of the accuracy loss in the custom loss. Default is 1.
            alpha_grad: Weight of the gradient-matching loss in the custom loss. Default is 0.001.
            alpha_kl: Weight of the KL loss in the custom loss. Default is 0.001.
            **params: Additional keyword arguments; accepted but not used by this class.

        Note:
            This class is intended to be used as a callback within a split-learning training loop.
        """
        super().__init__()
        self.grads = []
        self.p_hats = []
        self.eval_p_hats = []
        self.attack_party = attack_party
        self.epochs = epochs
        self.batch_size = batch_size
        self.agg_hiddens = []
        self.eval_hiddens = []
        self.attacker = ExploitAttackTrainer(
            device=attack_party,
            surrogate_model=surrogate_model,
            H_y=H_y,
            y_pri=y_pri,
        )
        self.alpha_acc = alpha_acc
        self.alpha_kl = alpha_kl
        self.alpha_grad = alpha_grad

    @staticmethod
    def convert_to_ndarray(*data: List) -> Union[List[jnp.ndarray], jnp.ndarray]:
        """Convert a tensor/ndarray, or a list/tuple of them, to jax arrays.

        Values that are neither torch tensors nor numpy arrays pass through unchanged.
        """

        def _convert_to_ndarray(hidden):
            if not isinstance(hidden, jnp.ndarray):
                if isinstance(hidden, torch.Tensor):
                    hidden = jnp.array(hidden.detach().cpu().numpy())
                if isinstance(hidden, np.ndarray):
                    hidden = jnp.array(hidden)
            return hidden

        if isinstance(data, Tuple) and len(data) == 1:
            # PYU packing/unpacking yields a 1-tuple when 'num_return' is not specified.
            data = data[0]
        if isinstance(data, (List, Tuple)):
            return [_convert_to_ndarray(d) for d in data]
        else:
            return _convert_to_ndarray(data)

    @staticmethod
    def convert_to_tensor(hidden: Union[List, Tuple], backend: str):
        """Convert ``hidden`` (or each element of a list/tuple) to ``torch.Tensor``.

        ``backend`` is accepted for interface compatibility and is not used.
        """
        if isinstance(hidden, (List, Tuple)):
            hidden = [torch.Tensor(d) for d in hidden]
        else:
            hidden = torch.Tensor(hidden)
        return hidden

    @staticmethod
    def _concat_rows(parts):
        """Concatenate a list of per-batch arrays along axis 0 (runs on the attacker's PYU)."""
        return np.concatenate(parts, axis=0)

    def on_train_begin(self, logs=None):
        self.steps_per_epoch = sf.reveal(
            self._workers[self.device_y].get_steps_per_epoch()
        )

    def on_agglayer_forward_end(self, hiddens=None):
        """Record label-holder predictions and aggregated hiddens in the last epoch.

        Train-stage data goes to ``p_hats``/``agg_hiddens``. Any other stage goes
        to ``eval_p_hats``/``eval_hiddens``. Predictions and hiddens are always
        routed by the same condition, so the two lists stay row-aligned.
        """
        victim_status = sf.reveal(self._workers[self.device_y].get_traing_status())
        epoch = victim_status["epoch"]
        stage = victim_status["stage"]
        slmodel_epochs = self.params["epochs"]
        if epoch != slmodel_epochs - 1:
            return

        is_train = stage == "train"
        # Bind the static converter locally so the remote closure below does not
        # capture (and serialize) the whole callback object via `self`.
        convert = self.convert_to_ndarray

        y_pred = self._workers[self.device_y].predict(hiddens)
        p_hat = self.device_y(convert)(y_pred)
        p_hat_list = self.p_hats if is_train else self.eval_p_hats
        p_hat_list.append(p_hat.to(self.attack_party))

        def _hidden_to_ndarray(forward_data):
            if isinstance(forward_data, ForwardData):
                forward_data = forward_data.hidden
            return convert(forward_data)

        if isinstance(hiddens, PYUObject):
            hiddens = [hiddens]
        agg_obj = hiddens[0].to(self.attack_party)
        agg_hidden = self.attack_party(_hidden_to_ndarray)(agg_obj)
        hidden_list = self.agg_hiddens if is_train else self.eval_hiddens
        hidden_list.append(agg_hidden)

    def on_agglayer_backward_end(self, gradients):
        """Record the attacker's received gradients during the last epoch."""
        victim_status = sf.reveal(self._workers[self.device_y].get_traing_status())
        epoch = victim_status["epoch"]
        slmodel_epochs = self.params["epochs"]
        if epoch == slmodel_epochs - 1:
            grad_client = gradients[self.attack_party]
            grad_client_np = self.attack_party(self.convert_to_ndarray)(grad_client)
            # Take the first element when the conversion returned a list.
            grad_client_np = self.attack_party(
                lambda g: g[0] if isinstance(g, list) else g
            )(grad_client_np)
            self.grads.append(grad_client_np.to(self.attack_party))

    def on_train_end(self, logs=None):
        """Concatenate the collected train-stage data and fit the surrogate."""
        self.p_hats = self.attack_party(self._concat_rows)(self.p_hats)
        self.grads = self.attack_party(self._concat_rows)(self.grads)
        self.agg_hiddens = self.attack_party(self._concat_rows)(self.agg_hiddens)

        self.attacker.train(
            x_train=self.agg_hiddens,
            p_hats=self.p_hats,
            grads=self.grads,
            epochs=self.epochs,
            batch_size=self.batch_size,
            alpha_acc=self.alpha_acc,
            alpha_kl=self.alpha_kl,
            alpha_grad=self.alpha_grad,
        )

    def get_attack_metrics(self):
        """Evaluate the surrogate on the collected eval-stage data.

        The collected lists are concatenated into locals rather than written back
        to ``self``. A second call would otherwise re-concatenate an already
        concatenated array.

        Returns:
            Revealed metrics dict with ``"accuracy"`` and ``"auc_value"``.
        """
        eval_hiddens = self.attack_party(self._concat_rows)(self.eval_hiddens)
        eval_p_hats = self.attack_party(self._concat_rows)(self.eval_p_hats)
        metrics = self.attacker.evaluate(
            x_test=eval_hiddens,
            p_hats=eval_p_hats,
            batch_size=128,
        )
        return sf.reveal(metrics)
